# -*- coding: utf-8 -*-
"""
Created on Wed Jan  6 21:17:39 2021

@author: 98330
"""
import gurobipy as gr
import numpy as np
import root_const
from base_utils.my_utils import load_json, save_json
from root_const import ROOT_PATH
from server_scheduling.server_switch_basic import ServerSwitchBasic
from min_cost_flow.min_cost_flow import GetAccurateFlow

class ServerSwitchEnergy(ServerSwitchBasic):
    """Energy-aware server switching models built on ServerSwitchBasic and solved with Gurobi.

    The optimisation methods share the same building blocks:

    * ``y[t, i]`` -- non-negative per-slot value for station ``i``; bounded by
      ``server_placement[i]`` (or by a caller-supplied ``y_array``) and required
      to cover the load routed to station ``i``.
    * ``u[t, s, slot]`` -- non-negative weight from station ``s`` towards the
      ``slot``-th entry of ``list_near[s]``.
    * ``a[t]`` -- value that every station's weights must sum to in slot ``t``.

    NOTE(review): ``location_divide`` comes from ServerSwitchBasic (not shown);
    ``list_assign[i]`` is assumed to hold the stations whose near list contains
    ``i`` -- the code looks ``i`` up in each of those near lists.
    """

    def __init__(self, location_array, k_near, k_relative, lanti1, long1, server_placement, time_interval,
                 lanti_history_list, long_history_list, frequ_history_list, adjust_rate_history,
                 lanti_list, long_list, frequ_list, adjust_rate):
        """Store the station layout, the server placement and the history / evaluation load data.

        param
        ----------
            location_array : array
                latitude/longitude of every base station, n x 2 (first column latitude,
                second column longitude); forwarded to ServerSwitchBasic
            k_near : int
                number of near base stations per station; forwarded to ServerSwitchBasic
            k_relative : int
                number of most related stations kept per station by find_relative_loca
            lanti1, long1 : list
                latitude / longitude of the reference station order; loads are re-indexed onto it
            server_placement : sequence
                per-station upper bound of y in server_usage_predict / server_switch
            time_interval :
                stored as-is; not read by any method of this class
            lanti_history_list, long_history_list, frequ_history_list : list of lists
                per-time-slot latitude / longitude / load records used for prediction
            adjust_rate_history : number
                scale factor applied to the history loads
            lanti_list, long_list, frequ_list : list of lists
                per-time-slot latitude / longitude / load records used for evaluation
            adjust_rate : number
                scale factor applied to the evaluation loads
        """
        ServerSwitchBasic.__init__(self, location_array, k_near)
        self.lanti1 = lanti1
        self.long1 = long1
        self.server_placement = server_placement
        self.time_interval = time_interval
        self.lanti_list = lanti_list
        self.long_list = long_list
        self.frequ_list = frequ_list
        self.lanti_history_list = lanti_history_list
        self.long_history_list = long_history_list
        self.frequ_history_list = frequ_history_list
        self.adjust_rate = adjust_rate
        self.adjust_rate_history = adjust_rate_history
        self.k_relative = k_relative
        # per-slot optimal `a` values of the history model, filled by server_usage_predict
        self.optimal_list = []

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _served_load(u, prefix, list_near, list_assign, load, server):
        """Return sum of ``u[prefix + (station, slot)] * load[station]`` over ``list_assign[server]``.

        ``slot`` is the position of ``server`` inside ``list_near[station]``. ``u`` may be a
        Gurobi tupledict (the result is a linear expression) or a numpy array (a number).
        The multiplication keeps ``u`` on the left so gurobipy's operator handles numpy scalars.
        """
        return sum(u[prefix + (station, list_near[station].index(server))] * load[station]
                   for station in list_assign[server])

    @staticmethod
    def _require_optimal(model, caller):
        """Raise RuntimeError unless ``model`` was solved to optimality."""
        if model.Status != gr.GRB.OPTIMAL:
            raise RuntimeError('%s: Gurobi model not solved to optimality (status %s)'
                               % (caller, model.Status))

    @staticmethod
    def _solution_array(var_dict, shape):
        """Return the solution values of a tupledict created by ``addVars(*shape)`` (2+ dims) as an array."""
        return np.array([var_dict[index].X for index in np.ndindex(*shape)], dtype=float).reshape(shape)

    def _rebuild_loads(self, lanti_lists, long_lists, frequ_lists, rate):
        """Re-index every time slot's load onto (self.lanti1, self.long1) and scale it by ``rate``."""
        final_rebuild = []
        for i in range(len(lanti_lists)):
            load_rebuild, _num = self.load_distribution(self.lanti1, self.long1,
                                                        lanti_lists[i], long_lists[i], frequ_lists[i])
            final_rebuild.append(list(np.array(load_rebuild) * rate))
        return final_rebuild

    # ------------------------------------------------------------ statistics

    def find_same_receiver(self, list1, list2):
        """Return the elements common to list1 and list2 (duplicates removed, order unspecified)."""
        return list(set(list1) & set(list2))

    def find_same_receiver_union(self):
        """Count how often each near station of every station is shared, and save it to server_union.json.

        For station i and every station j (i included) the common near stations are
        computed. A key's counter is set to 0 on its first hit and incremented on later
        hits, so each stored value equals the number of *other* stations whose near list
        also contains that key. Progress is printed every 100 stations.
        """
        list_near, _list_far, _list_assign = self.location_divide()
        all_result = []
        for i, base_list in enumerate(list_near):
            if i % 100 == 0:
                print(i)
            counts = {}
            for other_list in list_near:
                for receiver in self.find_same_receiver(base_list, other_list):
                    key = str(receiver)
                    counts[key] = counts[key] + 1 if key in counts else 0
            all_result.append(counts)
        path = ROOT_PATH + '/all_data/server_scheduling_result/server_union_statistic/'
        name = 'server_union.json'
        save_json(path, name, all_result)

    def find_relative_loca(self):
        """Load server_union.json and keep, per station, the k_relative keys with the highest counts.

        return
        ----------
            list_relative : list of lists
                for station i, up to k_relative station ids sorted by decreasing count
            list_relative_assign : list of lists
                inverse mapping: list_relative_assign[s] lists every i with s in list_relative[i]
        """
        union_relative = load_json(ROOT_PATH + '/all_data/server_scheduling_result/server_union_statistic/',
                                   'server_union.json')
        list_relative = []
        list_relative_assign = [[] for _ in range(len(union_relative))]
        for i, counts in enumerate(union_relative):
            ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
            top = [int(key) for key, _count in ranked][0:self.k_relative]
            list_relative.append(top)
            # iterate over what is actually available: a station may have < k_relative candidates
            for station in top:
                list_relative_assign[station].append(i)
        return list_relative, list_relative_assign

    def load_distribution(self, lanti_1, long_1, lanti_2, long_2, load):
        """Find the locations of loca2 that also appear in loca1 and re-index their load onto loca1.

        A loca2 record matches loca1 position p when the first occurrence of its latitude
        in lanti_1 and the first occurrence of its longitude in long_1 are both p.

        param
        -----------
            lanti_1 : list
                all latitudes of loca1
            lanti_2 : list
                all latitudes of loca2
            long_1 : list
                all longitudes of loca1
            long_2 : list
                all longitudes of loca2
            load : list
                all loads of loca2

        return
        ----------
            load_rebuild : list
                len(lanti_1) entries; matched positions hold the loca2 load, others are 0
            num : int
                number of loca2 records matched into loca1
        """
        # first-occurrence lookup tables: same result as list.index, without O(n) scans per record
        first_lanti = {}
        for position, value in enumerate(lanti_1):
            first_lanti.setdefault(value, position)
        first_long = {}
        for position, value in enumerate(long_1):
            first_long.setdefault(value, position)

        num = 0
        load_rebuild = [0] * len(lanti_1)
        for i in range(len(lanti_2)):
            # distinct sentinels so two misses never count as a match
            lanti_index = first_lanti.get(lanti_2[i], -1)
            long_index = first_long.get(long_2[i], -2)
            if lanti_index == long_index:
                load_rebuild[lanti_index] = load[i]
                num = num + 1
        return load_rebuild, num

    def sum_distribution(self):
        """Re-index each evaluation time slot's load onto (lanti1, long1) and scale by adjust_rate."""
        return self._rebuild_loads(self.lanti_list, self.long_list, self.frequ_list, self.adjust_rate)

    def sum_distribution_history(self):
        """Re-index each history time slot's load onto (lanti1, long1) and scale by adjust_rate_history."""
        # iterate over the *history* lists (previously sized by lanti_list by mistake)
        return self._rebuild_loads(self.lanti_history_list, self.long_history_list,
                                   self.frequ_history_list, self.adjust_rate_history)

    # ------------------------------------------------------------- models

    def server_usage_predict(self):
        """Maximise sum_t a[t] on the history loads with y bounded by server_placement.

        Stores the optimal a[t] values in self.optimal_list.

        raise
        ----------
            RuntimeError : if Gurobi does not reach an optimal solution
        """
        list_near, _list_far, list_assign = self.location_divide()
        final_rebuild = self.sum_distribution_history()
        time_part = len(final_rebuild)
        base_num = len(list_near)
        near_num = len(list_near[0])

        model_placement = gr.Model('Lip_placement')
        y = model_placement.addVars(time_part, base_num, vtype='S', lb=0)
        a = model_placement.addVars(time_part, vtype='S', lb=0)
        u = model_placement.addVars(time_part, base_num, near_num, vtype='S', lb=0)

        model_placement.setObjective(sum(a[t] for t in range(time_part)), gr.GRB.MAXIMIZE)
        for t in range(time_part):
            model_placement.addConstrs(
                y[t, j] <= self.server_placement[j] for j in range(base_num)
            )
            model_placement.addConstrs(
                sum(u[t, k, j] for j in range(near_num)) == a[t] for k in range(base_num)
            )
        for t in range(time_part):
            model_placement.addConstrs(
                self._served_load(u, (t,), list_near, list_assign, final_rebuild[t], i) <= y[t, i]
                for i in range(base_num)
            )
        model_placement.optimize()

        self._require_optimal(model_placement, 'server_usage_predict')
        self.optimal_list = [a[t].X for t in range(time_part)]

    def server_switch(self, max_rate, weight_switch, weight_await):
        """Minimise weighted switching + standby cost on the history loads.

        Calls server_usage_predict first. The per-slot target is 1 / max_rate when the predicted
        optimum is at least 1.2 times that target, otherwise 80 % of the predicted optimum.

        return
        ----------
            result_list : numpy array, shape (time_part, base_num)
                optimal y values
            final_optimal : list
                [optimal energy_cost]
            (0, 0) and prints "no solve" when the model is not solved to optimality
        """
        list_near, _list_far, list_assign = self.location_divide()
        final_rebuild = self.sum_distribution_history()
        time_part = len(final_rebuild)
        base_num = len(list_near)
        near_num = len(list_near[0])
        self.server_usage_predict()

        model_placement = gr.Model('Lip_placement')
        y = model_placement.addVars(time_part, base_num, vtype='S', lb=0)
        energy_cost = model_placement.addVar(vtype='S', lb=0)
        u = model_placement.addVars(time_part, base_num, near_num, vtype='S', lb=0)
        switch = model_placement.addVars(time_part - 1, base_num, vtype='S', lb=0)

        a = 1 / max_rate
        # `a <= opt - 0.2 * a` keeps a 20 % margin below the predicted optimum
        a_list = [a if a <= self.optimal_list[t] - 0.2 * a else self.optimal_list[t] * 0.8
                  for t in range(time_part)]

        model_placement.setObjective(energy_cost, gr.GRB.MINIMIZE)
        for t in range(time_part):
            model_placement.addConstrs(
                y[t, j] <= self.server_placement[j] for j in range(base_num)
            )
            model_placement.addConstrs(
                sum(u[t, k, j] for j in range(near_num)) == a_list[t] for k in range(base_num)
            )
        for t in range(time_part):
            model_placement.addConstrs(
                self._served_load(u, (t,), list_near, list_assign, final_rebuild[t], i) <= y[t, i]
                for i in range(base_num)
            )
        # switch[t, j] >= |y[t, j] - y[t+1, j]|
        for t in range(time_part - 1):
            model_placement.addConstrs(
                switch[t, j] >= y[t, j] - y[t + 1, j] for j in range(base_num)
            )
            model_placement.addConstrs(
                switch[t, j] >= y[t + 1, j] - y[t, j] for j in range(base_num)
            )

        model_placement.addConstr(
            energy_cost == weight_switch * sum(switch[t, j] for j in range(base_num) for t in range(time_part - 1))
            + weight_await * sum(y[t, j] for j in range(base_num) for t in range(time_part))
        )
        model_placement.optimize()

        if model_placement.Status != gr.GRB.OPTIMAL:
            print("no solve")
            return 0, 0
        result_list = self._solution_array(y, (time_part, base_num))
        # previously an always-empty slice; the index it started at is energy_cost
        final_optimal = [energy_cost.X]
        return result_list, final_optimal

    def server_usage(self, y_array):
        """Maximise sum_t a[t] on the evaluation loads with y[t, j] <= y_array[t, j].

        return
        ----------
            result_list (time_part x base_num array of y), final_optimal (list of a[t]),
            u_array (time_part x base_num x near_num array), list_assign, final_rebuild,
            base_num, list_near

        raise
        ----------
            RuntimeError : if Gurobi does not reach an optimal solution
        """
        list_near, _list_far, list_assign = self.location_divide()
        final_rebuild = self.sum_distribution()
        time_part = len(final_rebuild)
        base_num = len(list_near)
        near_num = len(list_near[0])

        model_placement = gr.Model('Lip_placement')
        y = model_placement.addVars(time_part, base_num, vtype='S', lb=0)
        a = model_placement.addVars(time_part, vtype='S', lb=0)
        u = model_placement.addVars(time_part, base_num, near_num, vtype='S', lb=0)

        model_placement.setObjective(sum(a[t] for t in range(time_part)), gr.GRB.MAXIMIZE)
        for t in range(time_part):
            model_placement.addConstrs(
                y[t, j] <= y_array[t, j] for j in range(base_num)
            )
            model_placement.addConstrs(
                sum(u[t, k, j] for j in range(near_num)) == a[t] for k in range(base_num)
            )
        for t in range(time_part):
            model_placement.addConstrs(
                self._served_load(u, (t,), list_near, list_assign, final_rebuild[t], i) <= y[t, i]
                for i in range(base_num)
            )
        model_placement.optimize()

        self._require_optimal(model_placement, 'server_usage')
        result_list = self._solution_array(y, (time_part, base_num))
        final_optimal = [a[t].X for t in range(time_part)]
        u_array = self._solution_array(u, (time_part, base_num, near_num))
        return result_list, final_optimal, u_array, list_assign, final_rebuild, base_num, list_near

    def server_usage2(self, y_array):
        """Like server_usage, but the served load is bounded by y_array directly (no y variables).

        return
        ----------
            y_array, final_optimal (list of a[t]), u_array (time_part x base_num x near_num),
            list_assign, final_rebuild, base_num, list_near

        raise
        ----------
            RuntimeError : if Gurobi does not reach an optimal solution
        """
        list_near, _list_far, list_assign = self.location_divide()
        final_rebuild = self.sum_distribution()
        time_part = len(final_rebuild)
        base_num = len(list_near)
        near_num = len(list_near[0])

        model_placement = gr.Model('Lip_placement')
        a = model_placement.addVars(time_part, vtype='S', lb=0)
        u = model_placement.addVars(time_part, base_num, near_num, vtype='S', lb=0)

        model_placement.setObjective(sum(a[t] for t in range(time_part)), gr.GRB.MAXIMIZE)
        for t in range(time_part):
            model_placement.addConstrs(
                sum(u[t, k, j] for j in range(near_num)) == a[t] for k in range(base_num)
            )
        for t in range(time_part):
            model_placement.addConstrs(
                self._served_load(u, (t,), list_near, list_assign, final_rebuild[t], i) <= y_array[t, i]
                for i in range(base_num)
            )
        model_placement.optimize()

        self._require_optimal(model_placement, 'server_usage2')
        final_optimal = [a[t].X for t in range(time_part)]
        u_array = self._solution_array(u, (time_part, base_num, near_num))
        return y_array, final_optimal, u_array, list_assign, final_rebuild, base_num, list_near

    def service_failure_statistic(self, y_array, max_rate):
        """Count, per slot whose optimal a[t] is below 1 / max_rate, the stations whose served load
        exceeds y_array[t, i] * a[t].

        return
        ----------
            numpy array of failure counts per slot, and the same divided by base_num
        """
        result_list, final_optimal, u_array, list_assign, final_rebuild, base_num, list_near = \
            self.server_usage2(y_array)
        a = 1 / max_rate
        service_failed_num = []
        for t, optimum in enumerate(final_optimal):
            if optimum >= a:
                # target reached in this slot: no failure
                service_failed_num.append(0)
                continue
            failed_num = 0
            for i in range(base_num):
                served = self._served_load(u_array, (t,), list_near, list_assign, final_rebuild[t], i)
                if served > result_list[t, i] * optimum:
                    failed_num = failed_num + 1
            service_failed_num.append(failed_num)
        failed = np.array(service_failed_num)
        return failed, failed / base_num

    def server_load(self, y_array, a):
        """Per evaluation slot, maximise the weighted load with each station's weights capped at 1 / a.

        return
        ----------
            final_result : list
                optimal objective per slot (NaN for slots not solved to optimality,
                so index t always refers to slot t)
            u_result : numpy array, shape (time_part, base_num, near_num)
                optimal weights (zeros for unsolved slots)
            final_rebuild : list
                the evaluation loads used
        """
        list_near, _list_far, list_assign = self.location_divide()
        final_rebuild = self.sum_distribution()
        time_part = len(final_rebuild)
        base_num = len(list_near)
        near_num = len(list_near[0])
        a_use = 1 / a
        final_result = []
        u_result = np.zeros((time_part, base_num, near_num))

        for t in range(time_part):
            model_placement = gr.Model('Lip_placement')
            final_optimize = model_placement.addVar(vtype='S', lb=0)
            u = model_placement.addVars(base_num, near_num, vtype='S', lb=0)
            model_placement.setObjective(final_optimize, gr.GRB.MAXIMIZE)
            model_placement.addConstr(
                final_optimize == sum(u[k, j] * final_rebuild[t][k] for j in range(near_num)
                                      for k in range(base_num))
            )
            model_placement.addConstrs(
                sum(u[k, j] for j in range(near_num)) <= a_use for k in range(base_num)
            )
            model_placement.addConstrs(
                self._served_load(u, (), list_near, list_assign, final_rebuild[t], k) <= y_array[t, k]
                for k in range(base_num)
            )
            model_placement.optimize()

            if model_placement.Status == gr.GRB.OPTIMAL:
                final_result.append(final_optimize.X)
                u_result[t, :, :] = self._solution_array(u, (base_num, near_num))
            else:
                # keep final_result aligned with the slot index
                final_result.append(float('nan'))
        return final_result, u_result, final_rebuild

    def server_hour_overload(self, y_array, a, hour_index, overload_rate, rand_num):
        """Solve the server_load model for one slot after overloading randomly chosen loaded stations.

        rand_num station indices are drawn (with replacement) among stations with positive load
        in slot hour_index; their load is multiplied by overload_rate. Nothing is overloaded when
        no station has positive load.

        return
        ----------
            optimal objective, u_array (base_num x near_num), the (overloaded) load array

        raise
        ----------
            RuntimeError : if Gurobi does not reach an optimal solution
        """
        list_near, _list_far, list_assign = self.location_divide()
        final_rebuild = self.sum_distribution()
        # float dtype so that `load * overload_rate` is not truncated for integer loads
        origin_data = np.array(final_rebuild[hour_index], dtype=float)
        index_list = np.flatnonzero(origin_data > 0)
        if index_list.size > 0:
            rand_index = np.random.randint(0, len(index_list), size=rand_num, dtype='int')
            final_index = index_list[rand_index]
            origin_data[final_index] = origin_data[final_index] * overload_rate
        base_num = len(list_near)
        near_num = len(list_near[0])
        a_use = 1 / a

        model_placement = gr.Model('Lip_placement')
        final_optimize = model_placement.addVar(vtype='S', lb=0)
        u = model_placement.addVars(base_num, near_num, vtype='S', lb=0)
        model_placement.setObjective(final_optimize, gr.GRB.MAXIMIZE)
        model_placement.addConstr(
            final_optimize == sum(u[k, j] * origin_data[k] for j in range(near_num)
                                  for k in range(base_num))
        )
        model_placement.addConstrs(
            sum(u[k, j] for j in range(near_num)) <= a_use for k in range(base_num)
        )
        model_placement.addConstrs(
            self._served_load(u, (), list_near, list_assign, origin_data, k) <= y_array[hour_index, k]
            for k in range(base_num)
        )
        model_placement.optimize()

        self._require_optimal(model_placement, 'server_hour_overload')
        final_optimal = final_optimize.X
        u_array = self._solution_array(u, (base_num, near_num))
        return final_optimal, u_array, origin_data

    def failed_statistic(self, y_array, max_rate, hour_index, overload_rate, rand_num):
        """Per slot, count loaded stations and those whose total weight is below 1 / max_rate.

        hour_index, overload_rate and rand_num are accepted for interface compatibility and unused.

        return
        ----------
            failed_result (list per slot), u_result (from server_load), service_num (list per slot)
        """
        final_result, u_result, final_rebuild = self.server_load(y_array, max_rate)
        time_part = len(u_result)
        threshold = 1 / max_rate
        failed_result = [0] * time_part
        service_num = [0] * time_part
        for t in range(time_part):
            load_temp = final_rebuild[t]
            for j in range(u_result.shape[1]):
                if load_temp[j] != 0:
                    service_num[t] += 1
                    if np.sum(u_result[t, j, :]) < threshold:
                        failed_result[t] += 1
        return failed_result, u_result, service_num

    def failed_statistic_hour(self, y_array, max_rate, hour_index, overload_rate, rand_num):
        """Same count as failed_statistic, for the single overloaded slot of server_hour_overload.

        return
        ----------
            failed_result (int), u_array (base_num x near_num), service_num (int)
        """
        final_optimal, u_array, origin_data = self.server_hour_overload(y_array, max_rate, hour_index,
                                                                        overload_rate, rand_num)
        threshold = 1 / max_rate
        failed_result = 0
        service_num = 0
        for j in range(len(u_array)):
            if origin_data[j] != 0:
                service_num += 1
                if np.sum(u_array[j, :]) < threshold:
                    failed_result += 1
        return failed_result, u_array, service_num

    def server_base_load(self, y_array, a):
        """Per evaluation slot, choose usage levels that minimise the y-weighted deviation from 1 / a.

        return
        ----------
            server_use_array, server_diff_array : numpy arrays, shape (time_part, base_num)
            u_result : numpy array, shape (time_part, base_num, near_num)
            (slots not solved to optimality stay zero)
        """
        list_near, _list_far, list_assign = self.location_divide()
        final_rebuild = self.sum_distribution()
        time_part = len(final_rebuild)
        base_num = len(list_near)
        near_num = len(list_near[0])
        a_use = 1 / a
        u_result = np.zeros((time_part, base_num, near_num))
        server_use_array = np.zeros((time_part, base_num))
        server_diff_array = np.zeros((time_part, base_num))

        for t in range(time_part):
            model_placement = gr.Model('Lip_placement')
            server_diff = model_placement.addVars(base_num, vtype='S')
            usage = model_placement.addVars(base_num, vtype='S', lb=0)
            u = model_placement.addVars(base_num, near_num, vtype='S', lb=0)
            model_placement.setObjective(sum(y_array[t, j] * server_diff[j] for j in range(base_num)),
                                         gr.GRB.MINIMIZE)
            # server_diff[j] >= |a_use - usage[j]|
            model_placement.addConstrs(
                server_diff[j] >= a_use - usage[j] for j in range(base_num)
            )
            model_placement.addConstrs(
                server_diff[j] >= usage[j] - a_use for j in range(base_num)
            )
            model_placement.addConstrs(
                sum(u[k, j] for j in range(near_num)) == a_use for k in range(base_num)
            )
            model_placement.addConstrs(
                self._served_load(u, (), list_near, list_assign, final_rebuild[t], k)
                <= y_array[t, k] * usage[k]
                for k in range(base_num)
            )
            model_placement.optimize()

            if model_placement.Status == gr.GRB.OPTIMAL:
                server_diff_array[t, :] = [server_diff[j].X for j in range(base_num)]
                server_use_array[t, :] = [usage[j].X for j in range(base_num)]
                u_result[t, :, :] = self._solution_array(u, (base_num, near_num))
        return server_use_array, server_diff_array, u_result

    def server_overload(self, y_array, a):
        """Per slot, sum (usage - 1/a) * y_array over stations whose usage from server_base_load exceeds 1/a."""
        server_use_array, _server_diff_array, _u_result = self.server_base_load(y_array, a)
        capacity = 1 / a
        list_failed = [0] * len(server_use_array)
        for t in range(len(server_use_array)):
            for j in range(server_use_array.shape[1]):
                if server_use_array[t, j] > capacity:
                    list_failed[t] += (server_use_array[t, j] - capacity) * y_array[t, j]
        return list_failed
                    
        
        
    
    
    
    
    